CREATE OR REPLACE FUNCTION "pipeline"."kmeans_complex"("source_table" text, "out_table_model" varchar, "out_table_result" varchar, "id_col" varchar, "feature_cols" varchar, "k" int4, "fn_dist" varchar, "agg_centroid" varchar, "max_iter" int4, "min_frac" float8)
  RETURNS "pg_catalog"."text" AS $BODY$
    import json
    import time
    import random
 
    def columnsToVecTable(sourceTable, featureCols, otherCols, outTable):
        """Pack the feature columns of sourceTable into one array column.

        Runs madlib.cols2vec(sourceTable, outTable, <features>, NULL, <keep>),
        where every name in the comma-separated featureCols / otherCols is
        wrapped in double quotes. otherCols are carried over unchanged. When
        otherCols is empty, NULL is passed for cols2vec's cols_to_output
        argument instead of an invalid empty identifier.

        Returns (True, "success" + sql) on success. On failure it drops
        outTable and outTable_summary and returns (False, error message).
        """
        quotedFeatures = ",".join('"{}"'.format(col) for col in featureCols.split(","))
        keepCols = [col for col in otherCols.split(",") if col]
        if keepCols:
            colsToOutput = "'{}'".format(",".join('"{}"'.format(col) for col in keepCols))
        else:
            # "".split(",") == [""] used to yield '""' (a zero-length
            # identifier); NULL is cols2vec's documented "nothing extra" value.
            colsToOutput = "NULL"
        sqlStr = "SELECT madlib.cols2vec('{}', '{}', '{}', NULL, {})".format(
            sourceTable, outTable, quotedFeatures, colsToOutput)
        try:
            plpy.execute(sqlStr)
        except Exception as e:
            plpy.execute("DROP TABLE IF EXISTS {}".format(outTable))
            # Previously formatted the literal "outTable", dropping the wrong table.
            plpy.execute("DROP TABLE IF EXISTS {}_summary".format(outTable))
            return False, "sql=[{}] errorMsg=[{}]".format(sqlStr, str(e))
        return True, "success" + sqlStr

    def logError(msg, status, result):
        # Stamp the failure onto the caller's report dict (mutated in place)
        # and return the whole report serialised as a JSON string.
        result.update({"status": status, "error_msg": msg})
        return json.dumps(result)
        
    def getTableMetaInfo(table_name):
        """Return {column_name: data_type} for a table or view.

        table_name may be unqualified ("foo") or schema-qualified
        ("schema.foo"). For a qualified name, only columns from that schema
        are returned. An unqualified name matches the table name in any
        schema listed in information_schema.columns.
        """
        parts = table_name.split(".")
        # The old `len(parts) > 0` / parts[1] test raised IndexError for
        # unqualified names and ignored the schema of qualified ones.
        if len(parts) > 1:
            plan = plpy.prepare(
                "select column_name, data_type from information_schema.columns"
                " where table_name = $1 and table_schema = $2",
                ["text", "text"])
            rows = plpy.execute(plan, [parts[-1], parts[-2]])
        else:
            plan = plpy.prepare(
                "select column_name, data_type from information_schema.columns"
                " where table_name = $1",
                ["text"])
            rows = plpy.execute(plan, [parts[-1]])
        return {row["column_name"]: row["data_type"] for row in rows}
    
    def begin_alg(source_table, out_table_model, out_table_result, id_col, feature_cols, k, fn_dist, agg_centroid, max_iter, min_frac):
        """Run MADlib k-means++ on source_table and materialise model + labels.

        source_table is either a relation name or a full "CREATE VIEW <name>
        AS ..." statement (any letter case). In the latter case the view is
        created first and dropped again before returning.

        Steps:
          1. pack feature_cols into a feature_vector array in a temporary
             pipeline.tmp_vector_* table (madlib.cols2vec);
          2. CREATE TABLE out_table_model from madlib.kmeanspp(...);
          3. CREATE TABLE out_table_result holding each feature, every other
             source column, and cluster_id (closest centroid's column_id,
             converted via concat), ordered by id_col.

        Returns a JSON string: status 0 with output_params listing the result
        table first and then the model table, or status 500 with error_msg set.
        Tables created along the way are dropped on failure.
        """
        result = {
            "status": 0,
            "error_msg": "success",
            "result": {
                "input_params": {
                    "source_table": source_table,
                    "out_table_model": out_table_model,
                    "out_table_result": out_table_result,
                    "id_col": id_col,
                    "feature_cols": feature_cols,
                    "k": k,
                    "fn_dist": fn_dist,
                    "agg_centroid": agg_centroid,
                    "max_iter": max_iter,
                    "min_frac": min_frac
                },
                "output_params": [
                ]
            }
        }
        model_table = out_table_model
        resultTable = out_table_result

        # One case-insensitive test, used for both creation and cleanup
        # (creation used to accept only all-lower or all-upper "create view").
        is_view = source_table.lower().startswith("create view")
        sourceTable = source_table
        if is_view:
            try:
                plpy.execute(source_table)
            except Exception as e:
                # CREATE VIEW failed, so there is no view of ours to drop. The
                # old code ran "DROP VIEW IF EXISTS <whole statement>", which
                # itself raised a syntax error.
                msg = "sq={}, errorMsg={}".format(source_table, str(e))
                return logError(msg, 500, result)
            # "create view <name> as ..." -> <name>
            sourceTable = source_table.split()[2]

        def cleanup(*tables):
            # Drop the view we created (if any) and each named table.
            if is_view:
                plpy.execute("DROP VIEW IF EXISTS {}".format(sourceTable))
            for table in tables:
                plpy.execute("DROP TABLE IF EXISTS {}".format(table))

        def quoteIdent(name):
            # Double-quote an identifier, escaping embedded double quotes.
            return '"{}"'.format(name.replace('"', '""'))

        randName = "".join(random.sample('zyxwvutsrqponmlkjihgfedcba', 7))
        matTable = "pipeline.tmp_vector_{}_{}".format(randName, int(time.time()))
        matSummary = "{}_summary".format(matTable)
        modelSummary = model_table + "_summary"

        features = feature_cols.split(",")
        sourceMeta = getTableMetaInfo(sourceTable)
        otherCols = list(set(sourceMeta.keys()) - set(features))
        flag, msg = columnsToVecTable(sourceTable, feature_cols, ",".join(otherCols), matTable)
        if not flag:
            cleanup(matTable, matSummary)
            return logError(msg, 500, result)

        model_sql = """
            CREATE TABLE {} as SELECT * FROM madlib.kmeanspp('{}', '{}', {}, '{}', '{}', {}, {})
        """.format(model_table, matTable, 'feature_vector', k,  fn_dist, agg_centroid, max_iter, min_frac)
        try:
            plpy.execute(model_sql)
            result["result"]["output_params"].append({
                "out_table_name": model_table,
                # list(): dict.keys() is not JSON-serialisable on Python 3.
                "output_cols": list(getTableMetaInfo(model_table).keys())
            })
            plpy.execute("DROP TABLE IF EXISTS {}".format(modelSummary))
        except Exception as e:
            cleanup(model_table, matSummary, matTable, modelSummary)
            return logError(str(e) + " sql=" + model_sql, 500, result)

        selectBody = ['data.feature_vector[{}] as "{}"'.format(i + 1, name)
                      for i, name in enumerate(features)]
        # Carried-over names come verbatim from information_schema, so they
        # must be quoted like the features; unquoted mixed-case names would be
        # folded to lower case and not found.
        selectBody.extend('data.{0} as {0}'.format(quoteIdent(col)) for col in otherCols)
        resultSql = ("CREATE TABLE {} AS SELECT {}, concat((madlib.closest_column(centroids, {}, '{}')).column_id) as cluster_id "
                     "from {} as data CROSS JOIN {} ORDER BY {}").format(
            resultTable, ",".join(selectBody), 'feature_vector', fn_dist, matTable, model_table, id_col)
        try:
            plpy.execute(resultSql)
            result["result"]["output_params"].append({
                "out_table_name": resultTable,
                "output_cols": list(getTableMetaInfo(resultTable).keys())
            })
        except Exception as e:
            cleanup(model_table, resultTable, modelSummary, matTable, matSummary)
            return logError(str(e) + " sql=" + resultSql, 500, result)

        # Success: keep model and result tables, drop intermediates and view.
        cleanup(matTable, matSummary, modelSummary)
        # Report the result table first, then the model table.
        result["result"]["output_params"].reverse()
        return json.dumps(result)
    
    if __name__ == "__main__":
        ret = begin_alg(source_table, out_table_model, out_table_result, id_col, feature_cols, k, fn_dist, agg_centroid, max_iter, min_frac)
        return ret
	
$BODY$
  LANGUAGE plpythonu VOLATILE
  COST 100